#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Jan 10 2022
Compare offline modeled dry deposition velocities with measured values - for paper
@author: arifeinberg
"""

#%%
#import os
#os.chdir('/Users/arifeinberg/target2/fs03/d0/arifein/python/')

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import csv
from scipy import stats
import cartopy.crs as ccrs
import cartopy.feature as cf
import scipy.io as sio

#%% Dry deposition functions
def summary_stats(obs, model):
    """Summary statistics for an observation-model comparison.

    Parameters
    ----------
    obs : array_like
         Observation values (1-D). Also the denominator of the normalized
         bias and error, so zero-valued observations yield inf/nan there.
    model : array_like
         Model values, paired position-by-position with ``obs``.

    Returns
    -------
    dict
        ``n`` number of pairs; ``mean_O``/``mean_M`` and
        ``median_O``/``median_M`` means and medians of obs/model;
        ``mnb``/``mne`` mean normalized bias/error (%); ``pcorr``/``scorr``
        Pearson/Spearman correlation of model vs. obs; ``rmse``
        root-mean-square error; ``mae`` mean absolute error; ``mdae``
        median absolute error.
    """
    # Convert to float ndarrays so plain lists are accepted and pandas
    # Series are paired by position rather than aligned on index labels.
    obs = np.asarray(obs, dtype=float)
    model = np.asarray(model, dtype=float)

    n = len(obs)
    diff = model - obs          # model minus observation, per pair
    abs_diff = np.abs(diff)

    # Central tendency of observations and model
    mean_O = obs.mean()
    mean_M = model.mean()
    median_O = np.median(obs)
    median_M = np.median(model)

    # Normalized bias/error are relative to obs, so a single small
    # observation can dominate them (noted for data point 16 in the
    # original analysis).
    mnb = np.mean(diff / obs) * 100.0
    mne = np.mean(abs_diff / obs) * 100.0

    # Correlations; only the coefficients are kept, p-values are discarded
    pcorr, _ = stats.pearsonr(model, obs)
    scorr, _ = stats.spearmanr(model, obs)

    # Error magnitudes
    rmse = np.sqrt(np.mean(diff**2))
    mae = np.mean(abs_diff)
    mdae = np.median(abs_diff)

    return {'n': n, 'mean_O': mean_O, 'mean_M': mean_M,
            'median_O': median_O, 'median_M': median_M,
            'mnb': mnb, 'mne': mne, 'pcorr': pcorr, 'scorr': scorr,
            'rmse': rmse, 'mae': mae, 'mdae': mdae}

#%% Load data from offline dry deposition code
# Offline-model parameters: they select which .mat file is read here and are
# reused in the saved figure's filename.
F0 = 3e-5
F0_am = 0.2

# f-string formatting of a float is identical to str(), e.g. 3e-5 -> '3e-05'
mat_path = f'misc_Data/LT_adj_deposition_F0_{F0}_{F0_am}_v5.mat'
data_mat = sio.loadmat(mat_path)

# Per-site model velocities (all sites; split into Amazon / non-Amazon later
# using the observation table's Region column)
DV = data_mat['DV_lt']

# Site coordinates are stored as 1 x N arrays in the .mat file; take row 0
obs_lat = data_mat['obs_lat'][0, :]
obs_lon = data_mat['obs_lon'][0, :]
site_types = data_mat['site_types'][:]

# Load observed dry deposition velocities
source = '../obs_datasets/dry_dep/SI_Forest_Hg_uptake_database.csv'
data_dd_f = pd.read_csv(source)

# Keep only rows that report a litterfall velocity; reset_index() renumbers
# them 0..N-1 (the previous index is kept as an 'index' column).
litter_col = 'Litterfall Dry Deposition velocity (cm s-1)'
df_dd = data_dd_f.loc[data_dd_f[litter_col].notna()].reset_index()

# Observed velocities (cm/s) and region labels as arrays. Despite the "_LC"
# suffix, no local-conditions correction is applied to these in this script.
DV_obs_litter_a_LC = df_dd[litter_col].values
DV_obs_total_LC = df_dd['Total foliar uptake deposition velocity (cm s-1)'].values
Region = df_dd['Region'].values

# Correct observed dry deposition velocity for local temperature and pressure
fn_met = '../GEOS-Chem_runs/run0100/OutputDir/GEOSChem.StateMet.alltime_m.nc4'
# alternative file in directory: run0100_StateMet_T_surfP.nc4

# Read the needed fields inside a context manager so the netCDF file handle
# is released as soon as they are in memory.
with xr.open_dataset(fn_met) as ds_met:
    # time-mean surface temperature in K (compared against stdtemp below)
    surf_T = ds_met.Met_TS.mean("time").values
    # time-mean surface pressure, scaled by 100 (hPa -> Pa)
    surf_pres = ds_met.Met_PSC2WET.mean("time").values * 100.0
    # grid coordinates, loaded into memory so they stay usable after close
    lon = ds_met.lon.load()
    lat = ds_met.lat.load()

# Disabled: per-site temperature/pressure correction of the observed
# velocities, kept for reference.
# NOTE(review): if this is re-enabled, two problems need fixing first:
#   * `site_types[i].flatten() is np.array([6])` tests object identity with a
#     freshly created array, so it is always False and the Manaus branch never
#     runs; an element-wise test (e.g. np.array_equal(..., [6])) was
#     presumably intended.
#   * `DV_obs_litter` / `DV_obs_total` in the correction lines are not defined
#     in this script (the loaded arrays are DV_obs_litter_a_LC and
#     DV_obs_total_LC).
# site_T = np.zeros(len(obs_lat))
# site_P = np.zeros(len(obs_lat))

# # loop through sites, find mean surf T and P
# for i in range(len(obs_lat)):
#     # find lon and lat indices of sites
#     lat_i = np.argmin(np.abs(np.array(lat)-obs_lat[i]))
#     lon_i = np.argmin(np.abs(np.array(lon)-obs_lon[i]))
#     if site_types[i].flatten() is np.array([6]): # use coordinates from Manaus
#         print(i)
#         lat_M = -3.12 # Manaus lat
#         lon_M = -60.02 # Manaus lon
#         # Find pressure/temperature correction factor for Manaus
#         lat_i = np.argmin(np.abs(np.array(lat)-lat_M))
#         lon_i = np.argmin(np.abs(np.array(lon)-lon_M))
    
#     site_T[i] = surf_T[lat_i, lon_i]
#     site_P[i] = surf_pres[lat_i, lon_i]
    
# Standard reference conditions; the correction factor used in this script is
# T / stdtemp * stdpressure / P.
stdpressure = 101325 # Pascals
stdtemp = 273.15 # Kelvins

# # make correction on deposition velocity for local conditions
# corr = site_T / stdtemp * stdpressure / site_P
# DV_obs_litter_LC = DV_obs_litter * site_T / stdtemp * stdpressure / site_P
# DV_obs_total_LC = DV_obs_total * site_T / stdtemp * stdpressure / site_P

# Separate Amazon (Region == 'South America') from all other sites
in_amazon = Region == 'South America'
not_amazon = Region != 'South America'

# Observed litterfall velocities
DV_amazon_obs = DV_obs_litter_a_LC[in_amazon]
DV_obs_litter_LC = DV_obs_litter_a_LC[not_amazon]

# Model: average each site over axis 1 first, then apply the same split
DV_m = DV.mean(axis=1)
DV_am, DV_else = DV_m[in_amazon], DV_m[not_amazon]

# Amazon: add the throughfall-minus-wet-deposition flux to the litterfall
# velocity, turning the flux into a velocity via the Manaus Hg0 concentration.
Hg0_Manaus = 1.04 # ng/m^3
amazon_TF = 53.8 # ug/m2/yr, throughfall minus wet deposition

# Nearest model grid cell to Manaus
lat_M, lon_M = -3.12, -60.02 # Manaus lat, lon
lat_i = np.abs(np.array(lat) - lat_M).argmin()
lon_i = np.abs(np.array(lon) - lon_M).argmin()

# Temperature/pressure correction factor T / stdtemp * stdpressure / P,
# taken from the time-mean met fields at that grid cell
T_cell = surf_T[lat_i, lon_i]
P_cell = surf_pres[lat_i, lon_i]
CF_M = T_cell / stdtemp * stdpressure / P_cell

# Unit conversion: (ug m-2 yr-1) / (ng m-3) -> cm s-1
ng_ug = 1000. # ng/ug
cm_m = 100. # cm/m
s_yr = 365.2425 * 24. * 60. * 60. # s in yr
unit_conv = ng_ug * cm_m / s_yr

# Throughfall contribution as a velocity, added to every Amazon site
DV_amazon_TF = amazon_TF / Hg0_Manaus * CF_M * unit_conv # cm/s
DV_amazon_obs_tot = DV_amazon_obs + DV_amazon_TF # cm/s

# Plotting
figure_drydep, axes1 = plt.subplots(nrows=1, ncols=1, figsize=[8, 6])
figure_drydep.subplots_adjust(bottom=0.15, left=0.15)


def _median_iqr_err(values):
    """Return the median of *values* and its interquartile range as
    matplotlib asymmetric error offsets ``[[median - Q1], [Q3 - median]]``."""
    med = np.median(values)
    lower = med - np.quantile(values, 0.25)
    upper = np.quantile(values, 0.75) - med
    return med, [[lower], [upper]]


#-all- individual site points (observed on x, modelled on y)

# Amazon litterfall sites
am_mod = DV_am
am_obs = DV_amazon_obs
a3 = axes1.plot(am_obs, am_mod, 'o', mew=1.5, zorder=3,
                markerfacecolor='#b2df8a', color='#b2df8a',
                label='Amazon Litterfall (Fostier et al., 2015)')
am_mod_med, error_am_mod = _median_iqr_err(am_mod)
am_obs_med, error_am_obs = _median_iqr_err(am_obs)

# Amazon litterfall + throughfall - wet dep: placed at the median of the
# shifted series, with the Amazon-litterfall IQR offsets doubled
# (assumed larger uncertainty).
am_obs_med_tot = np.median(DV_amazon_obs_tot)
error_am_obs_tot = [[2 * error_am_obs[0][0]], [2 * error_am_obs[1][0]]]

# Total foliar uptake: only sites where it is reported (non-NaN)
has_total = ~np.isnan(DV_obs_total_LC)
tot_mod = DV.mean(axis=1)[has_total]
tot_obs = DV_obs_total_LC[has_total]
a2 = axes1.plot(tot_obs, tot_mod, 'o', zorder=2,
                markerfacecolor='#e41a1c', color='#e41a1c',
                label='Litterfall + Throughfall - Wet Dep all')
tot_mod_med, error_tot_mod = _median_iqr_err(tot_mod)
tot_obs_med, error_tot_obs = _median_iqr_err(tot_obs)

# Non-Amazon litterfall sites
litt_mod = DV_else
litt_obs = DV_obs_litter_LC
a1 = axes1.plot(litt_obs, litt_mod, 'o', zorder=1, mew=1.5, color='silver',
                markerfacecolor='silver', label='Litterfall all')
litt_mod_med, error_litt_mod = _median_iqr_err(litt_mod)
litt_obs_med, error_litt_obs = _median_iqr_err(litt_obs)

#-medians- star markers at group medians, with IQR error bars
# Styling shared by every median errorbar call below
median_star_kw = dict(fmt='*', color='k', lw=3, mew=2, ms=25, capsize=2)

# Amazon litterfall + throughfall - wet dep
e5 = axes1.errorbar(am_obs_med_tot, am_mod_med,
                    xerr=error_am_obs_tot, yerr=error_am_mod,
                    markerfacecolor='#33a02c', zorder=14,
                    label='Amazon Litterfall + Throughfall – Wet Dep median (IQR)',
                    **median_star_kw)

# Amazon litterfall only
e4 = axes1.errorbar(am_obs_med, am_mod_med,
                    xerr=error_am_obs, yerr=error_am_mod,
                    markerfacecolor='#b2df8a', zorder=13,
                    label='Amazon Litterfall median (IQR)',
                    **median_star_kw)

# Obrist plot: single flux-tower point, no error bars.
# NOTE(review): obr_mod is the mean of the last row of DV, i.e. this assumes
# the flux-tower site is the last site in the model output — confirm.
# obr_obs is a hard-coded observed velocity.
obr_mod = DV[-1].mean()
obr_obs = 0.07230794
e3 = axes1.plot(obr_obs, obr_mod, '*', mew=2, ms=25,
                markerfacecolor='#377eb8', color='k', zorder=12,
                label='Flux Tower (Obrist et al., 2021)')

# Total foliar uptake (litterfall + throughfall - wet dep)
e2 = axes1.errorbar(tot_obs_med, tot_mod_med,
                    xerr=error_tot_obs, yerr=error_tot_mod,
                    markerfacecolor='#e41a1c', zorder=11,
                    label='Litterfall + Throughfall - Wet Dep median (IQR)',
                    **median_star_kw)

# Non-Amazon litterfall
e1 = axes1.errorbar(litt_obs_med, litt_mod_med,
                    xerr=error_litt_obs, yerr=error_litt_mod,
                    markerfacecolor='silver', zorder=10,
                    label='Litterfall median (IQR)',
                    **median_star_kw)

# print("obrist median obs")
# print(obr_obs)
# print("obrist median mod")
# print(obr_mod)

# print("AMTOT median obs")
# print(am_obs_med_tot)

# print("AM median obs")
# print(am_obs_med)
# print("AM 25%")
# print(np.quantile(am_obs, 0.25))
# print("AM 75%")
# print(np.quantile(am_obs, 0.75))
# print("AM median mod")
# print(am_mod_med)

# print("tot median obs")
# print(tot_obs_med)
# print("tot 25%")
# print(np.quantile(tot_obs, 0.25))
# print("tot 75%")
# print(np.quantile(tot_obs, 0.75))
# print("tot median mod")
# print(tot_mod_med)

# print("litt median obs")
# print(litt_obs_med)
# print("litt 25%")
# print(np.quantile(litt_obs, 0.25))
# print("litt 75%")
# print(np.quantile(litt_obs, 0.75))
# print("litt median mod")
# print(litt_mod_med)

# print("mod 25% tot")
# print(np.quantile(tot_mod, 0.25))
# print("mod 75% tot")
# print(np.quantile(tot_mod, 0.75))

# print("mod 25% litt")
# print(np.quantile(litt_mod, 0.25))
# print("mod 75% litt")
# print(np.quantile(litt_mod, 0.75))

# fix axes
label_fontsize = 26
axes1.set_xlabel('Obs dep. velocity (cm s$^{-1}$)', fontsize=label_fontsize)
axes1.set_ylabel('Model dep. velocity (cm s$^{-1}$)', fontsize=label_fontsize)

# Same 0..ax_max range on both axes. Only the Amazon series (observed total
# and model) set the upper bound; values from the other series above it
# would fall outside the plotted area.
ax_min = 0
largest_amazon = max(max(DV_amazon_obs_tot), max(DV_am))
ax_max = largest_amazon * 1.1
for set_limits in (axes1.set_xlim, axes1.set_ylim):
    set_limits([ax_min, ax_max])

# 1:1 reference line
l1 = axes1.plot([ax_min, ax_max], [ax_min, ax_max], '--', color='black',
                label='1-1 Line', zorder=0)

#sim_name = 'Model (f$_0$ = ' + str(F0) + ') vs. Observations'
# NOTE(review): this text says f0 = 10^-5 while F0 is set to 3e-5 at the top
# of the script; sim_name is only used by the commented-out set_title call.
sim_name = 'Model BASE (f$_0=10^{-5}$)'

#axes1.set_title(sim_name, fontsize = 18, fontweight='bold');
for tick_fn in (plt.xticks, plt.yticks):
    tick_fn(fontsize=20)
# handles, labels = axes1.get_legend_handles_labels()
# # sort both labels and handles by labels
# labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))

# Legend entries paired explicitly as (artist, text). The legend call itself
# is commented out, so these lists are built but not drawn.
legend_entries = [
    (l1[0], '1–1 Line'),
    (e1, 'Litterfall median'),
    (e2, 'Total foliar uptake median'),
    (e3[0], 'Flux tower (Obrist et al., 2021)'),
    (e4, 'Amazon litterfall median'),
    (e5, 'Amazon total foliar uptake median'),
    (a1[0], 'Litterfall all points'),
    (a2[0], 'Total foliar uptake all points'),
    (a3[0], 'Amazon litterfall (Fostier et al., 2015)'),
]
labels = [text for _, text in legend_entries]
handles = [artist for artist, _ in legend_entries]
#lgnd = axes1.legend(handles,labels,fontsize=18,bbox_to_anchor=(-0.01, 0.32, 1, 0.7), loc='upper left',framealpha=0.95)

# Output name encodes both model parameters (f-string matches str() for floats)
fig_name = f'Figures/offline_dvel_LT_obs_F0_{F0}_F0am_{F0_am}_v6.pdf'
figure_drydep.savefig(fig_name, bbox_inches='tight')

#figure_drydep.savefig('Figures/offline_dvel_LT_obs_F0_'+str(F0)+'_LC_all_v5.pdf',bbox_inches = 'tight')
# figure_drydep.savefig('Figures/offline_dvel_LT_obs_F0_'+str(F0)+'_LC_all_v4_med.pdf',bbox_inches = 'tight')

